{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Description:\n",
    "这里我们尝试建立一个PNN网络来完成一个ctr预测的问题。 关于Pytorch的建模流程， 主要有四步：\n",
    "1. 准备数据\n",
    "2. 建立模型\n",
    "3. 训练模型\n",
    "4. 使用和保存\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 导入包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T03:05:48.187781Z",
     "start_time": "2020-10-17T03:05:48.178807Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, Dataset, TensorDataset\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "from torchkeras import summary, Model\n",
    "\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T02:05:32.132381Z",
     "start_time": "2020-10-17T02:05:32.107448Z"
    }
   },
   "outputs": [],
   "source": [
    "# 导入数据， 数据已经处理好了 preprocess/下\n",
    "train_set = pd.read_csv('preprocessed_data/train_set.csv')\n",
    "val_set = pd.read_csv('preprocessed_data/val_set.csv')\n",
    "test_set = pd.read_csv('preprocessed_data/test.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T02:05:37.516289Z",
     "start_time": "2020-10-17T02:05:37.450464Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>I1</th>\n",
       "      <th>I2</th>\n",
       "      <th>I3</th>\n",
       "      <th>I4</th>\n",
       "      <th>I5</th>\n",
       "      <th>I6</th>\n",
       "      <th>I7</th>\n",
       "      <th>I8</th>\n",
       "      <th>I9</th>\n",
       "      <th>I10</th>\n",
       "      <th>I11</th>\n",
       "      <th>I12</th>\n",
       "      <th>I13</th>\n",
       "      <th>C1</th>\n",
       "      <th>C2</th>\n",
       "      <th>C3</th>\n",
       "      <th>C4</th>\n",
       "      <th>C5</th>\n",
       "      <th>C6</th>\n",
       "      <th>C7</th>\n",
       "      <th>C8</th>\n",
       "      <th>C9</th>\n",
       "      <th>C10</th>\n",
       "      <th>C11</th>\n",
       "      <th>C12</th>\n",
       "      <th>C13</th>\n",
       "      <th>C14</th>\n",
       "      <th>C15</th>\n",
       "      <th>C16</th>\n",
       "      <th>C17</th>\n",
       "      <th>C18</th>\n",
       "      <th>C19</th>\n",
       "      <th>C20</th>\n",
       "      <th>C21</th>\n",
       "      <th>C22</th>\n",
       "      <th>C23</th>\n",
       "      <th>C24</th>\n",
       "      <th>C25</th>\n",
       "      <th>C26</th>\n",
       "      <th>Label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000381</td>\n",
       "      <td>0.000473</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.009449</td>\n",
       "      <td>0.082147</td>\n",
       "      <td>0.004825</td>\n",
       "      <td>0.003656</td>\n",
       "      <td>0.040447</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.081081</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.001299</td>\n",
       "      <td>0</td>\n",
       "      <td>236</td>\n",
       "      <td>331</td>\n",
       "      <td>326</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>64</td>\n",
       "      <td>17</td>\n",
       "      <td>1</td>\n",
       "      <td>518</td>\n",
       "      <td>95</td>\n",
       "      <td>1207</td>\n",
       "      <td>586</td>\n",
       "      <td>2</td>\n",
       "      <td>600</td>\n",
       "      <td>526</td>\n",
       "      <td>3</td>\n",
       "      <td>166</td>\n",
       "      <td>116</td>\n",
       "      <td>2</td>\n",
       "      <td>14</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>60</td>\n",
       "      <td>27</td>\n",
       "      <td>327</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000127</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.075768</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000710</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>33</td>\n",
       "      <td>51</td>\n",
       "      <td>1158</td>\n",
       "      <td>150</td>\n",
       "      <td>11</td>\n",
       "      <td>4</td>\n",
       "      <td>760</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>410</td>\n",
       "      <td>295</td>\n",
       "      <td>1203</td>\n",
       "      <td>167</td>\n",
       "      <td>13</td>\n",
       "      <td>75</td>\n",
       "      <td>50</td>\n",
       "      <td>8</td>\n",
       "      <td>145</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1166</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>469</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000381</td>\n",
       "      <td>0.000236</td>\n",
       "      <td>0.137931</td>\n",
       "      <td>0.004804</td>\n",
       "      <td>0.030185</td>\n",
       "      <td>0.007841</td>\n",
       "      <td>0.062157</td>\n",
       "      <td>0.024126</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.054054</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.015584</td>\n",
       "      <td>33</td>\n",
       "      <td>122</td>\n",
       "      <td>829</td>\n",
       "      <td>8</td>\n",
       "      <td>1</td>\n",
       "      <td>6</td>\n",
       "      <td>223</td>\n",
       "      <td>17</td>\n",
       "      <td>1</td>\n",
       "      <td>573</td>\n",
       "      <td>699</td>\n",
       "      <td>899</td>\n",
       "      <td>290</td>\n",
       "      <td>2</td>\n",
       "      <td>266</td>\n",
       "      <td>1066</td>\n",
       "      <td>8</td>\n",
       "      <td>99</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>588</td>\n",
       "      <td>0</td>\n",
       "      <td>11</td>\n",
       "      <td>434</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.011696</td>\n",
       "      <td>0.000473</td>\n",
       "      <td>0.034483</td>\n",
       "      <td>0.002180</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.032907</td>\n",
       "      <td>0.003193</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.003896</td>\n",
       "      <td>33</td>\n",
       "      <td>12</td>\n",
       "      <td>147</td>\n",
       "      <td>79</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>1105</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>534</td>\n",
       "      <td>631</td>\n",
       "      <td>412</td>\n",
       "      <td>487</td>\n",
       "      <td>13</td>\n",
       "      <td>273</td>\n",
       "      <td>894</td>\n",
       "      <td>6</td>\n",
       "      <td>188</td>\n",
       "      <td>34</td>\n",
       "      <td>1</td>\n",
       "      <td>862</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>67</td>\n",
       "      <td>27</td>\n",
       "      <td>380</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.004450</td>\n",
       "      <td>0.001064</td>\n",
       "      <td>0.034483</td>\n",
       "      <td>0.006119</td>\n",
       "      <td>0.039457</td>\n",
       "      <td>0.001206</td>\n",
       "      <td>0.009141</td>\n",
       "      <td>0.000887</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.027027</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.003896</td>\n",
       "      <td>56</td>\n",
       "      <td>17</td>\n",
       "      <td>938</td>\n",
       "      <td>555</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>943</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>194</td>\n",
       "      <td>866</td>\n",
       "      <td>17</td>\n",
       "      <td>211</td>\n",
       "      <td>12</td>\n",
       "      <td>105</td>\n",
       "      <td>323</td>\n",
       "      <td>4</td>\n",
       "      <td>348</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>296</td>\n",
       "      <td>0</td>\n",
       "      <td>9</td>\n",
       "      <td>23</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    I1        I2        I3        I4        I5        I6        I7        I8  ...  C20   C21  C22  C23  C24  C25  C26  Label\n",
       "0  0.0  0.000381  0.000473  0.000000  0.009449  0.082147  0.004825  0.003656  ...    2    14    0    0   60   27  327      0\n",
       "1  0.0  0.000127  0.000000  0.000000  0.075768  0.000000  0.000000  0.000000  ...    0  1166    0    1  469    0    0      0\n",
       "2  0.0  0.000381  0.000236  0.137931  0.004804  0.030185  0.007841  0.062157  ...    0   588    0   11  434    0    0      0\n",
       "3  0.0  0.011696  0.000473  0.034483  0.002180  0.000000  0.000000  0.032907  ...    1   862    0    0   67   27  380      1\n",
       "4  0.0  0.004450  0.001064  0.034483  0.006119  0.039457  0.001206  0.009141  ...    0   296    0    9   23    0    0      0\n",
       "\n",
       "[5 rows x 40 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_set.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T02:06:19.519130Z",
     "start_time": "2020-10-17T02:06:19.498186Z"
    }
   },
   "outputs": [],
   "source": [
    "# 这里需要把特征分成数值型和离散型， 因为后面的模型里面离散型的特征需要embedding， 而数值型的特征直接进入了stacking层， 处理方式会不一样\n",
    "data_df = pd.concat((train_set, val_set, test_set))\n",
    "\n",
    "dense_feas = ['I'+str(i) for i in range(1, 14)]\n",
    "sparse_feas = ['C'+str(i) for i in range(1, 27)]\n",
    "\n",
    "# 定义一个稀疏特征的embedding映射， 字典{key: value}, key表示每个稀疏特征， value表示数据集data_df对应列的不同取值个数， 作为embedding输入维度\n",
    "sparse_feas_map = {}\n",
    "for key in sparse_feas:\n",
    "    sparse_feas_map[key] = data_df[key].nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T02:06:34.393602Z",
     "start_time": "2020-10-17T02:06:34.386578Z"
    }
   },
   "outputs": [],
   "source": [
    "feature_info = [dense_feas, sparse_feas, sparse_feas_map]  # 这里把特征信息进行封装， 建立模型的时候作为参数传入"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备数据 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T02:11:55.192585Z",
     "start_time": "2020-10-17T02:11:55.157720Z"
    }
   },
   "outputs": [],
   "source": [
    "# 把数据构建成数据管道\n",
    "dl_train_dataset = TensorDataset(torch.tensor(train_set.drop(columns='Label').values).float(), torch.tensor(train_set['Label']).float())\n",
    "dl_val_dataset = TensorDataset(torch.tensor(val_set.drop(columns='Label').values).float(), torch.tensor(val_set['Label']).float())\n",
    "\n",
    "dl_train = DataLoader(dl_train_dataset, shuffle=True, batch_size=16)\n",
    "dl_val = DataLoader(dl_val_dataset, shuffle=True, batch_size=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T02:12:52.219078Z",
     "start_time": "2020-10-17T02:12:52.207109Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 39]) torch.Size([16])\n"
     ]
    }
   ],
   "source": [
    "# 查看一下数据\n",
    "for b in iter(dl_train):\n",
    "    print(b[0].shape, b[1].shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 建立模型\n",
    "建立模型有三种方式：\n",
    "1. 继承nn.Module基类构建自定义模型\n",
    "2. nn.Sequential按层顺序构建模型\n",
    "3. 继承nn.Module基类构建模型， 并辅助应用模型容器进行封装\n",
    "\n",
    "这里我们依然会使用第三种方式， 因为embedding依然是很多层。 模型的结构如下：\n",
    "\n",
    "![](img/pnn.png)\n",
    "\n",
    "这里简单的分析一下这个模型， 说几个比较重要的细节：\n",
    "1. 这里的输入， 由于都进行了embedding， 所以这里应该是类别型的特征， 关于数值型的特征， 在把类别都交叉完了之后， 才把数值型的特征加入进去\n",
    "2. 交叉层这里， 左边和右边其实用的同样的一层， 有着同样的神经单元个数， 只不过这里进行计算的时候， 得分开算，左边的是单个特征的线性组合lz， 而右边是两两特征进行交叉后特征的线性组合lp。 得到这两个之后， 两者进行相加得到最终的组合， 然后再relu激活才是交叉层的输出。\n",
    "3. 交叉层这里图上给出的是**一个神经元**内部的计算情况， 注意这里是一个神经元内部的计算， 这些圈不是表示多个神经元。\n",
    "\n",
    "下面说一下代码的逻辑：\n",
    "1. 首先， 我们定义一个DNN神经网络， 这个也就是上面图片里面的交叉层上面的那一部分结构， 也就是很多个全连接层的一个网络， 之所以单独定义这样的一个网络， 是因为更加的灵活， 加多少层， 每层神经元个数是多少我们就可以自己指定了， 这里会涉及到一个小操作技巧。\n",
    "2. 然后就是定义整个PNN网络， 核心部分就是在前向传播。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T02:44:59.360643Z",
     "start_time": "2020-10-17T02:44:59.346968Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(256, 128), (128, 64)]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# python生成元素对测试， 这个可以帮助我们定义一个列表的全连接层\n",
    "a = [256, 128, 64]\n",
    "list(zip(a[:-1], a[1:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T03:04:47.201126Z",
     "start_time": "2020-10-17T03:04:47.180184Z"
    }
   },
   "outputs": [],
   "source": [
    "# 定义一个全连接层的神经网络\n",
    "class DNN(nn.Module):\n",
    "    \n",
    "    def __init__(self, hidden_units, dropout=0.):\n",
    "        \"\"\"\n",
    "        hidden_units:列表， 每个元素表示每一层的神经单元个数，比如[256, 128, 64]，两层网络， 第一层神经单元128个，第二层64，注意第一个是输入维度\n",
    "        dropout: 失活率\n",
    "        \"\"\"\n",
    "        super(DNN, self).__init__()\n",
    "        \n",
    "        # 下面创建深层网络的代码 由于Pytorch的nn.Linear需要的输入是(输入特征数量， 输出特征数量)格式， 所以我们传入hidden_units， \n",
    "        # 必须产生元素对的形式才能定义具体的线性层， 且Pytorch中线性层只有线性层， 不带激活函数。 这个要与tf里面的Dense区分开。\n",
    "        self.dnn_network = nn.ModuleList([nn.Linear(layer[0], layer[1]) for layer in list(zip(hidden_units[:-1], hidden_units[1:]))])\n",
    "        self.dropout = nn.Dropout(p=dropout)\n",
    "    \n",
    "    # 前向传播中， 需要遍历dnn_network， 不要忘了加激活函数\n",
    "    def forward(self, x):\n",
    "        \n",
    "        for linear in self.dnn_network:\n",
    "            x = linear(x)\n",
    "            x = F.relu(x)\n",
    "        \n",
    "        x = self.dropout(x)\n",
    "        \n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T03:06:58.583895Z",
     "start_time": "2020-10-17T03:06:58.571928Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Linear-1                    [-1, 8]             136\n",
      "            Linear-2                    [-1, 4]              36\n",
      "            Linear-3                    [-1, 2]              10\n",
      "            Linear-4                    [-1, 1]               3\n",
      "           Dropout-5                    [-1, 1]               0\n",
      "================================================================\n",
      "Total params: 185\n",
      "Trainable params: 185\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.000061\n",
      "Forward/backward pass size (MB): 0.000122\n",
      "Params size (MB): 0.000706\n",
      "Estimated Total Size (MB): 0.000889\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# 测试一下这个网络\n",
    "hidden_units = [16, 8, 4, 2, 1]        # 层数和每一层神经单元个数， 由我们自己定义了\n",
    "dnn = DNN(hidden_units)\n",
    "summary(dnn, input_shape=(16,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 下面我们定义真正的PNN网络， 上半部分已经搞定， 下半部分比较难弄的就是交叉层这里的计算了\n",
    "# 这里的逻辑是底层输入（类别型特征) -> embedding层 -> product 层\n",
    "class PNN(nn.Module):\n",
    "    \n",
    "    def __init__(self, feature_info, hidden_units, mode='in', dnn_dropout=0., embed_dim=10, outdim=1):\n",
    "        \"\"\"\n",
    "        DeepCrossing：\n",
    "            feature_info: 特征信息（数值特征， 类别特征， 类别特征embedding映射)\n",
    "            hidden_units: 列表， 全连接层的每一层神经单元个数， 这里注意一下， 第一层神经单元个数实际上是hidden_units[1]， 因为hidden_units[0]是输入层\n",
    "            dropout: Dropout层的失活比例\n",
    "            embed_dim: embedding的维度m\n",
    "            outdim: 网络的输出维度\n",
    "        \"\"\"\n",
    "        super(PNN, self).__init__()\n",
    "        self.dense_feas, self.sparse_feas, self.sparse_feas_map = feature_info\n",
    "        self.field_num = len(self.sparse_feas)\n",
    "        self.mode = mode\n",
    "        self.embed_dim = embed_dim\n",
    "        \n",
    "         # embedding层， 这里需要一个列表的形式， 因为每个类别特征都需要embedding\n",
    "        self.embed_layers = nn.ModuleDict({\n",
    "            'embed_' + str(key): nn.Embedding(num_embeddings=val, embedding_dim=self.embed_dim)\n",
    "            for key, val in self.sparse_feas_map.items()\n",
    "        })\n",
    "        \n",
    "        # product层， 由于交叉这里分为两部分， 一部分是单独的特征运算， 也就是上面结构的z部分， 一个是两两交叉， p部分， 而p部分还分为了内积交叉和外积交叉\n",
    "        # 所以， 这里需要自己定义参数张量进行计算\n",
    "        # z部分的w， 这里的神经单元个数是hidden_units[0], 上面我们说过， 全连接层的第一层神经单元个数是hidden_units[1]， 而0层是输入层的神经\n",
    "        # 单元个数， 正好是product层的输出层  关于维度， 这个可以看在博客中的分析\n",
    "        self.w_z = torch.rand([self.field_num, self.embed_dim, hidden_units[0]], requires_grad=True)\n",
    "        \n",
    "        # p部分, 分内积和外积两种操作\n",
    "        if self.mode == 'in':\n",
    "            self.w_p = torch.rand([self.field_num, self.field_num, hidden_units[0]], requires_grad=True)\n",
    "        else:\n",
    "            self.w_p = torch.rand([self.embed_dim, self.embed_dim, hidden_units[0]], requires_grad=True)\n",
    "        \n",
    "        self.l_b = torch.rand([hidden_units[0], ], requires_grad=True)\n",
    "        \n",
    "        # dnn 层\n",
    "        self.dnn_network = DNN(hidden_units[1:], dnn_dropout)\n",
    "        self.dense_final = nn.Linear(hidden_units[-1], 1)\n",
    "    \n",
    "    \n",
    "    def forward(self, x):\n",
    "        dense_inputs, sparse_inputs = x[:, :13], x[:, 13:]   # 数值型和类别型数据分开\n",
    "        sparse_inputs = sparse_inputs.long()      # 需要转成长张量， 这个是embedding的输入要求格式\n",
    "        sparse_embeds = [self.embed_layers['embed_'+key](sparse_inputs[:, i]) for key, i in zip(self.sparse_feas_map.keys(), range(sparse_inputs.shape[1]))]   \n",
    "        # 上面这个sparse_embeds的维度是 [field_num, None, embed_dim] \n",
    "        sparse_embeds = sparse_embeds.permute((1, 0, 2))   # [None, field_num, embed_dim]\n",
    "        z = sparse_embeds\n",
    "        \n",
    "        # product layer\n",
    "        \n",
    "        lz = torch.mm(z.view(z.shape[0], -1), self.w_z.permute((2, 0, 1).view(self.w_z.shape[2]).T)\n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T15:39:46.115782Z",
     "start_time": "2020-10-17T15:39:46.106769Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-1.4233, -0.5386,  0.8526,  0.2451, -1.8669],\n",
       "         [ 1.1324, -1.6918,  0.5853, -1.8199,  0.4612],\n",
       "         [ 1.4145, -1.1161, -1.2833,  0.3778,  0.4652]],\n",
       "\n",
       "        [[-2.1107,  0.1400,  1.4397,  1.7950,  0.0228],\n",
       "         [-0.2800, -1.2522, -0.7620,  1.5946,  0.7884],\n",
       "         [-0.1130, -0.9751,  0.9359,  2.8937, -0.0386]]])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.randn([2, 3, 5])\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T15:39:50.241803Z",
     "start_time": "2020-10-17T15:39:50.223851Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.4233, -0.5386,  0.8526,  0.2451, -1.8669,  1.1324, -1.6918,  0.5853,\n",
       "         -1.8199,  0.4612,  1.4145, -1.1161, -1.2833,  0.3778,  0.4652],\n",
       "        [-2.1107,  0.1400,  1.4397,  1.7950,  0.0228, -0.2800, -1.2522, -0.7620,\n",
       "          1.5946,  0.7884, -0.1130, -0.9751,  0.9359,  2.8937, -0.0386]])"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "   # 2* 15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T15:41:46.422959Z",
     "start_time": "2020-10-17T15:41:46.412941Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 1.6515e-01],\n",
       "         [-2.5681e-02],\n",
       "         [ 4.3027e-01],\n",
       "         [-2.8399e-01],\n",
       "         [-1.2000e+00]],\n",
       "\n",
       "        [[ 1.4511e-03],\n",
       "         [-4.3237e-02],\n",
       "         [ 4.9474e-02],\n",
       "         [ 1.1974e+00],\n",
       "         [-1.7887e+00]],\n",
       "\n",
       "        [[ 5.3539e-01],\n",
       "         [-1.1398e+00],\n",
       "         [-6.2405e-01],\n",
       "         [-9.5267e-01],\n",
       "         [ 1.6035e-01]]])"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = torch.randn([3, 5, 1])\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T15:51:18.214876Z",
     "start_time": "2020-10-17T15:51:18.201910Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.9607],\n",
       "        [-2.0508]])"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.mm(a.view(2, -1), b.permute((2, 0, 1)).view(1, -1).T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-17T15:21:52.613990Z",
     "start_time": "2020-10-17T15:21:52.601980Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ -1.3029,   7.8991],\n",
       "        [ -0.0210, -10.4559],\n",
       "        [  3.9463,   7.3142],\n",
       "        [  1.0481,   6.3338],\n",
       "        [  0.6330,  13.7092],\n",
       "        [ -2.9125,   0.8828],\n",
       "        [  3.8470,  -0.7602],\n",
       "        [ -6.2366,  -2.2087],\n",
       "        [  2.5922,   5.1750],\n",
       "        [  1.6220,   0.9939]])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "print('a', a)\n",
    "b = torch.randn([3, 5, 2])\n",
    "print('b', b)\n",
    "a = a.view(10, -1)\n",
    "print(a)\n",
    "b = b.view(-1, 2)\n",
    "torch.matmul(a, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
